# 评估指标定义
# 包含各种评估指标的计算函数
import evaluate
import numpy as np


def get_compute_metrics_fn(task='classification'):
    """获取评估指标计算函数"""
    if task == 'classification':
        accuracy = evaluate.load('accuracy')
        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            predictions = np.argmax(predictions, axis=1)
            return accuracy.compute(predictions=predictions, references=labels)
        return compute_metrics
    # 可以添加其他任务的评估指标
    return None